// AuthenticatedHttpServlet.java - A base class for authenticating servlets.
//
// Copyright (C) 1999-2002  Smart Software Consulting
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
//
// Smart Software Consulting
// 1688 Silverwood Court
// Danville, CA  94526-3079
// USA
//
// http://www.smartsc.com
//

package com.smartsc.http;

import java.io.IOException;
import java.io.PrintWriter;

import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.smartsc.util.Base64;

/**
 * Base class for servlets that require HTTP "Basic" authentication.
 * <p>
 * {@code doGet} and {@code doPost} are final. Each one checks the request's
 * {@code Authorization} header. If the header does not yield credentials
 * accepted by {@link #authorized(HttpServletRequest, String, String)}, a
 * 401 response with a {@code Basic} challenge is sent. Otherwise the call
 * is delegated to {@link #doAuthorizedGet} or {@link #doAuthorizedPost}.
 */
public abstract class AuthenticatedHttpServlet extends HttpServlet
{
	/** Scheme prefix of a Basic Authorization header, including the space. */
	private static final String BASIC_PREFIX = "Basic ";

	/**
	 * Extracts Basic credentials from the request and checks them.
	 *
	 * @param req the incoming request
	 * @return true only if the request carries a Basic {@code Authorization}
	 *         header whose decoded value contains a ':' and the subclass
	 *         accepts the resulting username/password pair
	 */
	protected boolean authorized( HttpServletRequest req)
	{
		// Get authorization info from request header
		String authEncoded = req.getHeader( "Authorization");

		// The auth-scheme token is case-insensitive (RFC 7235), so accept
		// "basic", "BASIC", etc. as well as "Basic".
		if( (authEncoded == null)
		||  !authEncoded.regionMatches(
				true, 0, BASIC_PREFIX, 0, BASIC_PREFIX.length()) )
		{
			return false;
		}

		// Decode the credentials token that follows the scheme; surrounding
		// whitespace is not part of the base64 text.
		String authDecoded = Base64.decode(
			authEncoded.substring( BASIC_PREFIX.length()).trim());

		// NOTE(review): defensive guard; Base64.decode's contract is not
		// visible here, so it is unclear whether it can return null.
		if( authDecoded == null)
		{
			return false;
		}

		// Credentials have the form "username:password". Split on the FIRST
		// colon so that the password itself may contain colons.
		int colon = authDecoded.indexOf( ':');
		if( colon < 0)
		{
			return false;
		}

		String username = authDecoded.substring( 0, colon);
		String password = authDecoded.substring( colon + 1);
		return authorized( req, username, password);
	}

	/**
	 * Handles POST: sends a 401 challenge if the request is not authorized,
	 * otherwise delegates to {@link #doAuthorizedPost}.
	 */
	public final void doPost( HttpServletRequest req, HttpServletResponse res)
		throws IOException
	{
		if( ensureAuthorized( req, res))
		{
			doAuthorizedPost( req, res);
		}
	}

	/**
	 * Handles GET: sends a 401 challenge if the request is not authorized,
	 * otherwise delegates to {@link #doAuthorizedGet}.
	 */
	public final void doGet( HttpServletRequest req, HttpServletResponse res)
		throws IOException
	{
		if( ensureAuthorized( req, res))
		{
			doAuthorizedGet( req, res);
		}
	}

	/**
	 * Checks authorization, sending the 401 response when it fails.
	 *
	 * @return true if the caller may proceed with the request
	 */
	private boolean ensureAuthorized(
		HttpServletRequest req, HttpServletResponse res)
		throws IOException
	{
		if( authorized( req))
		{
			return true;
		}
		sendUnauthorized( req, res);
		return false;
	}

	/**
	 * Writes a 401 Unauthorized response carrying a Basic challenge for this
	 * servlet's realm, plus a short HTML explanation.
	 */
	private void sendUnauthorized(
		HttpServletRequest req, HttpServletResponse res)
		throws IOException
	{
		// Query the realm once so the header and body always agree.
		// String.valueOf keeps the original "null" rendering for a null realm.
		String realm = String.valueOf( realmName( req));

		res.setStatus( HttpServletResponse.SC_UNAUTHORIZED);
		res.setHeader(
			"WWW-authenticate", "Basic realm=\"" + quoteHeaderValue( realm) + "\"");
		res.setContentType("text/html");

		PrintWriter pw = res.getWriter();
		pw.println( "<html>");
		pw.println( "<head><title>");
		pw.println( "Unauthorized");
		pw.println( "</title></head>");
		pw.println( "<body>");
		pw.println( "<H1>Unauthorized</H1>");
		pw.print  ( "Proper authorization is required for <b>");
		// The realm may be derived from the request, so escape it before
		// emitting it as HTML to avoid markup/script injection.
		pw.print  ( escapeHtml( realm));
		pw.println( "</b>.");
		pw.println( "Either your browser does not perform authorization,");
		pw.println( "or your authorization has failed.");
		pw.println( "</body></html>");
	}

	/**
	 * Escapes text for use inside an HTTP quoted-string. It backslash-escapes
	 * '"' and '\', and replaces CR/LF with spaces so the value cannot split
	 * the header.
	 */
	private static String quoteHeaderValue( String s)
	{
		StringBuffer out = new StringBuffer( s.length());
		for( int i = 0; i < s.length(); i++)
		{
			char c = s.charAt( i);
			if( (c == '"') || (c == '\\'))
			{
				out.append( '\\').append( c);
			}
			else if( (c == '\r') || (c == '\n'))
			{
				out.append( ' ');
			}
			else
			{
				out.append( c);
			}
		}
		return out.toString();
	}

	/** Escapes the HTML-significant characters &amp; &lt; &gt; &quot; and '. */
	private static String escapeHtml( String s)
	{
		StringBuffer out = new StringBuffer( s.length());
		for( int i = 0; i < s.length(); i++)
		{
			char c = s.charAt( i);
			switch( c)
			{
				case '&':  out.append( "&amp;");  break;
				case '<':  out.append( "&lt;");   break;
				case '>':  out.append( "&gt;");   break;
				case '"':  out.append( "&quot;"); break;
				case '\'': out.append( "&#39;");  break;
				default:   out.append( c);        break;
			}
		}
		return out.toString();
	}

	/**
	 * Decides whether the given credentials grant access to the request.
	 *
	 * @param req      the incoming request
	 * @param username user name decoded from the Basic credentials
	 * @param password password decoded from the Basic credentials
	 * @return true to allow the request, false to answer with 401
	 */
	protected abstract boolean authorized(
		HttpServletRequest req, String username, String password);

	/**
	 * Returns the realm name used in the Basic challenge and in the 401
	 * page. The value is escaped before it is written to either.
	 */
	protected abstract String realmName( HttpServletRequest req);

	/** Handles a GET request that has already passed authorization. */
	protected abstract void doAuthorizedGet(
		HttpServletRequest req, HttpServletResponse res)
	throws IOException;

	/** Handles a POST request that has already passed authorization. */
	protected abstract void doAuthorizedPost(
		HttpServletRequest req, HttpServletResponse res)
	throws IOException;
}
